#################################################################################
# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
# All Rights Reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################
# Some parts of the code are borrowed from: https://github.com/pytorch/vision
# with the following license:
#
# BSD 3-Clause License
#
# Copyright (c) Soumith Chintala 2016,
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################

import torch
import numpy as np
from ... import xnn
from .utils import *
from .mobilenetv2 import MobileNetV2TV, get_config as get_config_mnetv2


# Public entry points exported by this module.
__all__ = ['MobileNetV2TVDenseNAS', 'MobileNetV2TVDenseNASBase', 'mobilenet_v2_tv_dense_nas']


###################################################
def get_config_mnetv2_dense_nas():
    """Return the default ConfigNode for the MobileNetV2 dense-NAS model.

    Every field here can be overridden by passing ``model_config`` to
    :class:`MobileNetV2TVDenseNAS`, which merges the user config over
    these defaults.
    """
    model_config = xnn.utils.ConfigNode()
    model_config.input_channels = 3
    model_config.num_classes = 1000
    model_config.width_mult = 1.
    model_config.expand_ratio = 6
    model_config.output_stride = 32
    model_config.activation = xnn.layers.DefaultAct2d
    model_config.use_blocks = False
    model_config.kernel_size = 3
    model_config.dropout = False
    model_config.linear_dw = False
    model_config.layer_setting = None   # filled by define_nw() when None
    model_config.num_connection = 4     # max dense skip connections per block
    model_config.group_size_dws = 1     # group size for depthwise-separable convs
    # MobileNetV2TVDenseNASBase reads model_config.strides; declare a default
    # here so the attribute always exists (None => use default (2,2,2,2,2)).
    model_config.strides = None

    return model_config


###################################################
def define_nw(nw_type='mobv2_orig', s=None, model_config=None):
    """Return the layer-setting dictionary for the dense-NAS MobileNetV2.

    Args:
        nw_type: network variant selector; currently unused -- the dense-NAS
            MobileNetV2 table is always returned (kept for interface
            compatibility with callers).
        s: per-stage strides; currently unused. It was only consumed by an
            unused reference table of the original MobileNetV2 layout
            (dead code, removed in this revision).
        model_config: model configuration; currently unused.

    Returns:
        dict with keys:
            'chs':        output channels per block
            'fm_sizes':   feature-map size per block (for a 224x224 input)
            'stage':      stage index per block
            'num_layers': layer count per block
            'last_ch':    channels of the final 1x1 head conv
    """
    dense_nas_mobv2 = dict()
    dense_nas_mobv2['chs'] =      [28, 63, 63, 63, 63, 63, 63, 63, 63, 126, 126, 126, 189, 189, 315, 378, 378]
    dense_nas_mobv2['fm_sizes'] = [112,56, 28, 28, 28, 14, 14, 14, 14,  14,  14,   7,   7,   7,   7,   7,   7]
    # NOTE(review): 'stage' has 18 entries while the other lists have 17 --
    # looks like a stray trailing element; confirm against its consumers.
    dense_nas_mobv2['stage'] =      [0, 1,  2,  2,  2,  3,  3,  3,  4,   4,   4,   5,   5,   5,   6,   6,   6, 7]
    dense_nas_mobv2['num_layers'] = [0, 3,  3,  3,  3,  3,  3,  3,  3,   3,   3,   3,   3,   3,   0,   0,   0]
    dense_nas_mobv2['last_ch'] = 1984  # (original MobileNetV2 uses 1280)

    return dense_nas_mobv2


###################################################
class InvertedResidual(torch.nn.Module):
    """MobileNetV2-style inverted-residual block (pw-expand, dw, pw-linear).

    Args:
        inp: input channels.
        oup: output channels.
        stride: spatial stride of the depthwise conv (must be 1 or 2).
        hidden_dim: expanded (intermediate) channel count.
        activation: activation layer class (e.g. xnn.layers.DefaultAct2d).
        kernel_size: depthwise conv kernel size.
        group_size_dws: group size for the depthwise conv (groups = hidden_dim // group_size_dws).
        group_pw: groups for the expanding pointwise conv.
        group_lin: groups for the linear (projection) pointwise conv.
        el_add: accepted for interface compatibility; currently unused here.
        linear_dw: if True, the depthwise conv is linear and a trailing
            activation is appended after the projection.
        en_use_res_connect: allow the residual connection (only taken when
            stride == 1 and inp == oup).
    """
    def __init__(self, inp, oup, stride, hidden_dim, activation, kernel_size,
                group_size_dws, group_pw, group_lin, el_add, linear_dw, en_use_res_connect=True):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        self.use_res_connect = self.stride == 1 and inp == oup and en_use_res_connect

        # Capture the activation class before it may be disabled below: the
        # trailing activation used for linear_dw must instantiate the real
        # class. (BUG FIX: previously `activation` could already have been
        # overwritten to False, making `activation(inplace=True)` a TypeError.)
        act_cls = activation

        activation_dw = (False if linear_dw else activation)
        groups_dw = hidden_dim//group_size_dws

        # Hard-coded switch: when False (the current setting) all
        # normalization/activation inside the block is disabled.
        en_bn_relu = False
        if en_bn_relu:
            normalization = True
        else:
            activation = False
            activation_dw = False
            normalization = False

        layers = []
        if hidden_dim != inp:
            # pw: expanding pointwise conv (skipped when no expansion needed)
            layers.append(xnn.layers.ConvNormAct2d(inp, hidden_dim, kernel_size=1, activation=activation, groups=group_pw, normalization=normalization))

        if en_bn_relu:
            layers.extend([
                # dw
                xnn.layers.ConvNormAct2d(hidden_dim, hidden_dim, kernel_size=kernel_size, stride=stride, activation=activation_dw, groups=groups_dw, normalization=normalization),
                # pw-linear
                torch.nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False, groups=group_lin),
                xnn.layers.DefaultNorm2d(oup),
            ])
        else:
            layers.extend([
                # dw
                xnn.layers.ConvNormAct2d(hidden_dim, hidden_dim, kernel_size=kernel_size, stride=stride, activation=activation_dw, groups=groups_dw, normalization=normalization),
                # pw-linear (no norm in this configuration)
                torch.nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False, groups=group_lin),
            ])

        if linear_dw:
            # linear depthwise => apply the activation after the projection
            layers.append(act_cls(inplace=True))

        self.conv = torch.nn.Sequential(*layers)

        if self.use_res_connect:
            self.add = xnn.layers.AddBlock(signed=True)

    def forward(self, x):
        """Apply the block; adds the input back when the residual is active."""
        x1 = self.conv(x)
        if self.use_res_connect:
            x1 = self.add((x, x1))

        return x1

###################################################
class MobileNetV2TVDenseNASBase(torch.nn.Module):
    """MobileNetV2-style classification backbone with dense-NAS connectivity.

    The network is a stem conv, a sequence of residual blocks
    (``trunc_block``) where each block after index 2 additionally receives
    projected skip connections (``branch_block``) from up to
    ``num_connection`` earlier blocks, a per-block repeat stage
    (``repeat_block``), a 1x1 head conv and a linear classifier.
    """
    def __init__(self, ResidualBlock, model_config):
        super().__init__()
        self.model_config = model_config
        self.num_classes = self.model_config.num_classes
        # Per-stage downsampling strides; default is the standard MobileNetV2 schedule.
        strides = model_config.strides if (model_config.strides is not None) else (2,2,2,2,2)

        if self.model_config.layer_setting is None:
            self.model_config.layer_setting = define_nw(nw_type='mobv2_for_latency_measure', s=strides, model_config=self.model_config)

        activation = self.model_config.activation
        linear_dw = self.model_config.linear_dw
        kernel_size = self.model_config.kernel_size
        expansion_ratio = self.model_config.expand_ratio

        # building the stem (first conv layer)
        input_channel = self.model_config.layer_setting['chs'][0]
        features = [xnn.layers.ConvNormAct2d(3, input_channel, kernel_size=kernel_size, stride=2, activation=activation)]

        ##### MBConv Blocks #####
        op_chs = self.model_config.layer_setting['chs']
        fm_sizes = self.model_config.layer_setting['fm_sizes']
        group_size_dws = model_config.group_size_dws
        group_pw = 1
        group_lin = 1
        el_add = False
        num_connection = self.model_config.num_connection

        # first residual block keeps channel count and resolution
        output_channel = input_channel
        intermediate_c = input_channel
        stride = 1
        features.append(ResidualBlock(input_channel, output_channel, stride, intermediate_c, activation, kernel_size,
                            group_size_dws, group_pw, group_lin, el_add, linear_dw, en_use_res_connect=False))

        # placeholder lists indexed by block; index 0 stays None (unused)
        self.branch_block = [None] * len(op_chs)
        self.trunc_block = [None] * len(op_chs)
        self.repeat_block = [None] * len(op_chs)

        for blk_idx in range(1, len(op_chs)):
            # stride derived from the ratio of consecutive feature-map sizes
            stride = fm_sizes[blk_idx-1] // fm_sizes[blk_idx]
            intermediate_c = input_channel * expansion_ratio
            output_channel = op_chs[blk_idx]
            self.trunc_block[blk_idx] = ResidualBlock(input_channel, output_channel, stride, intermediate_c, activation, kernel_size,
                            group_size_dws, group_pw, group_lin, el_add, linear_dw, en_use_res_connect=False)

            if blk_idx > 2:
                # dense skip connections from up to num_connection earlier blocks
                cur_branch_block_ops = []
                for src_idx in range(max(blk_idx-num_connection, 1), blk_idx-1):
                    input_channel = op_chs[src_idx]
                    stride = fm_sizes[src_idx] // fm_sizes[blk_idx]
                    # connections that would need a total stride > 2 are not made
                    if (stride == 2) or (stride == 1):
                        intermediate_c = input_channel*expansion_ratio
                        cur_branch_block_ops.append(ResidualBlock(input_channel, output_channel, stride, intermediate_c, activation, kernel_size,
                                    group_size_dws, group_pw, group_lin, el_add, linear_dw, en_use_res_connect=False))
                self.branch_block[blk_idx] = torch.nn.Sequential(*cur_branch_block_ops)

            intermediate_c = output_channel * expansion_ratio
            stride = 1
            self.repeat_block[blk_idx] = ResidualBlock(output_channel, output_channel, stride, intermediate_c, activation, kernel_size,
                            group_size_dws, group_pw, group_lin, el_add, linear_dw, en_use_res_connect=False)
            input_channel = output_channel

        # building the classifier head
        if self.model_config.num_classes is not None:
            # last 1x1 conv expanding to the head channel count
            self.last_feature = xnn.layers.ConvNormAct2d(input_channel, self.model_config.layer_setting['last_ch'], kernel_size=1, activation=activation)

            self.classifier = torch.nn.Sequential(
                torch.nn.Dropout(0.2) if self.model_config.dropout else xnn.layers.BypassBlock(),
                torch.nn.Linear(self.model_config.layer_setting['last_ch'], self.num_classes),
            )

        # register the blocks so their parameters are tracked by the module
        # (ModuleList tolerates the None placeholders at index 0/<=2)
        self.features = torch.nn.Sequential(*features)
        self.trunc_block = torch.nn.ModuleList(self.trunc_block)
        self.branch_block = torch.nn.ModuleList(self.branch_block)
        self.repeat_block = torch.nn.ModuleList(self.repeat_block)

        # weight initialization
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)
            elif isinstance(m, torch.nn.BatchNorm2d):
                torch.nn.init.ones_(m.weight)
                torch.nn.init.zeros_(m.bias)
            elif isinstance(m, torch.nn.Linear):
                torch.nn.init.normal_(m.weight, 0, 0.01)
                torch.nn.init.zeros_(m.bias)


    def forward(self, x):
        """Run the backbone; returns class logits when num_classes is set."""
        x = self.features[0](x)
        x = self.features[1](x)
        op_chs = self.model_config.layer_setting['chs']
        fm_sizes = self.model_config.layer_setting['fm_sizes']
        blk_data = [None] * len(op_chs)
        num_connection = self.model_config.num_connection
        for blk_idx in range(1, len(op_chs)):
            x = self.trunc_block[blk_idx](x)
            if blk_idx > 2:
                local_idx = 0
                # add the projected dense skip connections; the eligibility
                # rule must match __init__ so local_idx lines up with the
                # modules stored in branch_block[blk_idx]
                for src_idx in range(max(blk_idx-num_connection, 1), blk_idx-1):
                    stride = fm_sizes[src_idx] // fm_sizes[blk_idx]
                    if (stride == 2) or (stride == 1):
                        x = x + self.branch_block[blk_idx][local_idx](blk_data[src_idx])
                        local_idx += 1
            x = self.repeat_block[blk_idx](x)
            blk_data[blk_idx] = x

        # NOTE(review): last_feature is only created when num_classes is not
        # None, but it is called unconditionally here -- a features-only
        # configuration (num_classes=None) would fail; confirm intended use.
        x = self.last_feature(x)
        if self.num_classes is not None:
            xnn.utils.print_once('=> feature size is: ', x.size())
            x = torch.nn.functional.adaptive_avg_pool2d(x,(1,1))
            x = torch.flatten(x, 1)
            x = self.classifier(x)

        return x

class MobileNetV2TVDenseNAS(MobileNetV2TVDenseNASBase):
    """Dense-NAS MobileNetV2 built from the default config.

    An optional ``model_config`` keyword argument is merged over the
    defaults from get_config_mnetv2_dense_nas() before construction.
    """
    def __init__(self, **kwargs):
        model_config = get_config_mnetv2_dense_nas()
        # merge a user-provided config over the defaults
        if 'model_config' in kwargs:
            model_config = model_config.merge_from(kwargs['model_config'])
        super().__init__(InvertedResidual, model_config)


#######################################################################
class mobilenet_v2_tv_dense_nas(MobileNetV2TVDenseNAS):
    # NOTE: lowercase class name is intentional so this can be used like a
    # torchvision-style model factory function.
    def __init__(self, pretrained=False, **kwargs):
        """Build the model and optionally load pretrained weights.

        pretrained: falsy to skip loading; otherwise forwarded to
            xnn.utils.load_weights (presumably a checkpoint path/URL --
            TODO confirm the expected type).
        """
        super().__init__(**kwargs)
        if pretrained:
            # NOTE(review): rebinding the local name `self` has no effect on
            # the caller; this relies on load_weights mutating the model in
            # place -- confirm against xnn.utils.load_weights.
            self = xnn.utils.load_weights(self, pretrained)